import numbers
import time

import numpy as np


class UpliftBandit(object):
    """Multi-variable bandit in which each arm shifts the means of a subset of variables.

    All arms share a common ``baseline`` vector of per-variable means. Arm
    ``a`` redraws the means of ``ns_affected[a]`` randomly chosen variables
    and pushes each redrawn mean at least ``minimum_uplift`` away from its
    baseline value. The expected reward of an arm is the sum of its
    per-variable means. Subclasses provide ``feedback(arm)``.

    Args:
        n_arms: Number of arms.
        n_variables: Number of variables observed per pull.
        ns_affected: Number of variables each arm affects; either a single
            number shared by all arms or a sequence indexed by arm.
        minimum_uplift: Requested minimum absolute shift of every affected
            variable's mean. Baselines are drawn uniformly from
            ``[minimum_uplift, 1 - minimum_uplift)``.
        rng_initialize: Seed of the generator used to build the instance.
        rng_feedback: Seed of the generator stored as ``rng_feedback``.

    Note:
        After construction ``minimum_uplift`` is overwritten by
        ``compute_statistics`` with the realized smallest nonzero shift.
    """

    def __init__(self, n_arms, n_variables, ns_affected,
                 minimum_uplift=0,
                 rng_initialize=42, rng_feedback=7):
        self.n_arms = n_arms
        self.n_variables = n_variables
        if isinstance(ns_affected, numbers.Number):
            ns_affected = np.repeat(ns_affected, n_arms)
        self.minimum_uplift = minimum_uplift
        # Initialize random generators
        self.rng_initialize = np.random.default_rng(rng_initialize)
        self.rng_feedback = np.random.default_rng(rng_feedback)
        # Initialize expected payoffs
        self.baseline = self.rng_initialize.uniform(
            low=minimum_uplift, high=1 - minimum_uplift, size=n_variables)
        self.affected_sets = []
        self.means = np.tile(self.baseline, (n_arms, 1))
        for arm in range(n_arms):
            n_affected = ns_affected[arm]
            # Draw order (choice, then random) is kept so seeded instances
            # reproduce the same means.
            affected_set = self.rng_initialize.choice(
                n_variables, n_affected, replace=False)
            baseline_affected = self.baseline[affected_set]
            means_affected = self.rng_initialize.random(n_affected)
            # Push each redrawn mean at least minimum_uplift away from its
            # baseline, in the direction it was drawn.
            means_affected = np.where(
                means_affected <= baseline_affected,
                np.minimum(means_affected, baseline_affected - minimum_uplift),
                np.maximum(means_affected, baseline_affected + minimum_uplift))
            self.means[arm, affected_set] = means_affected
            self.affected_sets.append(affected_set)
        self.compute_statistics()

    # Compute some basic quantities
    def compute_statistics(self):
        """Derive summary quantities from ``self.means`` and ``self.baseline``.

        Sets ``rewards`` (expected reward per arm), ``optimal_arm``,
        ``optimal_expected_reward``, ``uplifts`` (reward minus the baseline
        reward), ``gap`` (best minus second-best reward; needs at least two
        arms) and ``minimum_uplift`` (smallest nonzero per-variable shift from
        the baseline, or 0 when no mean differs from it).
        """
        self.rewards = np.sum(self.means, axis=1)
        self.optimal_arm = np.argmax(self.rewards)
        self.optimal_expected_reward = np.max(self.rewards)
        self.uplifts = self.rewards - np.sum(self.baseline)
        rewards_sorted = np.sort(self.rewards)
        self.gap = rewards_sorted[-1] - rewards_sorted[-2]
        # Take the smallest *positive* shift directly. Indexing element 1 of
        # the sorted unique shifts is only correct when an exact 0 is present
        # (some variable unaffected): with every variable affected it returns
        # the second-smallest shift, and with no shifts at all it raises.
        shifts = np.abs(self.means - self.baseline)
        positive = shifts[shifts > 0]
        self.minimum_uplift = positive.min() if positive.size else 0.0

    def expected_reward(self, arm):
        """Return the expected (summed) reward of ``arm``."""
        return self.rewards[arm]

    def init_feedback(self, rng_feedback):
        """Reseed the generator stored as ``rng_feedback``."""
        self.rng_feedback = np.random.default_rng(rng_feedback)


class BernoulliUpliftBandit(UpliftBandit):
    """Uplift bandit whose per-variable observations are 0/1 draws."""

    def feedback(self, arm):
        """Sample one observation vector for ``arm``.

        One uniform draw per variable is taken from ``self.rng_feedback``;
        a variable reads 1.0 when its draw falls below that variable's mean
        under ``arm`` and 0.0 otherwise.
        """
        uniforms = self.rng_feedback.random(self.n_variables)
        success = uniforms < self.means[arm]
        return success.astype('float')


class GaussianUpliftBandit(UpliftBandit):
    """Uplift bandit with correlated Gaussian observation noise.

    Observations are ``means[arm] + scale * A @ z`` with ``z`` a standard
    normal vector, so the noise covariance is
    ``cov_mat = scale**2 * A @ A.T``. ``scale`` is chosen so that the largest
    diagonal entry of ``cov_mat`` equals 1.

    Args:
        n_arms, n_variables, ns_affected, minimum_uplift, rng_initialize,
        rng_feedback: Forwarded to ``UpliftBandit``.
        A: Noise mixing matrix of shape ``(n_variables, n_variables)``, given
            as an ndarray or any array-like (e.g. nested lists). When
            ``None``, a matrix with entries drawn from ``[0, 1)`` is used.
        rng_covmat: Seed used to generate ``A`` when ``A`` is ``None``.

    Raises:
        ValueError: If ``A`` is not of shape ``(n_variables, n_variables)``.
    """

    def __init__(self, n_arms, n_variables, ns_affected,
                 minimum_uplift=0, A=None,
                 rng_initialize=42, rng_feedback=7, rng_covmat=253):
        if A is None:
            rng = np.random.default_rng(rng_covmat)
            self.A = rng.random((n_variables, n_variables))
        else:
            # Array-likes such as nested lists have no `.T`; ndarrays pass
            # through np.asarray unchanged (no copy).
            self.A = np.asarray(A)
            # feedback() multiplies A by an n_variables vector and adds the
            # result to an n_variables mean, so A must be square.
            if self.A.shape != (n_variables, n_variables):
                raise ValueError(
                    f'A must have shape ({n_variables}, {n_variables}), '
                    f'got {self.A.shape}')
        AAT = self.A @ self.A.T
        self.scale = np.sqrt(1 / np.max(np.diag(AAT)))
        self.cov_mat = self.scale**2 * AAT
        # Standard deviation of the sum of all observed variables.
        self.noise_std = np.sqrt(np.sum(self.cov_mat))
        super().__init__(n_arms, n_variables, ns_affected,
                         minimum_uplift, rng_initialize, rng_feedback)

    def feedback(self, arm):
        """Sample one observation vector for ``arm``.

        Draws ``n_variables`` standard normals from ``self.rng_feedback`` and
        mixes them through ``A`` before scaling and adding the arm's means.
        """
        mean = self.means[arm]
        noise = self.A @ self.rng_feedback.normal(size=self.n_variables)
        return mean + self.scale * noise


# Run the learner for n_rounds; return the per-round expected regrets and the pulled arms.
def interact(bandit, learner, n_rounds, print_step=None):
    """Let ``learner`` play ``bandit`` for ``n_rounds`` rounds.

    Each round the learner chooses an arm via ``act()``, the bandit produces
    feedback for that arm, and the learner receives it via
    ``update(arm, feedback)``. When ``print_step`` is given, the round index
    and the elapsed wall-clock seconds are printed every ``print_step`` rounds.

    Returns:
        ``(regrets, arm_his)``: for each round, the expected regret
        ``bandit.optimal_expected_reward - bandit.expected_reward(arm)`` of
        the pulled arm, and the list of pulled arms, both in round order.
    """
    regrets, arm_his = [], []
    t0 = time.time()
    report = print_step is not None
    for rnd in range(n_rounds):
        if report and rnd % print_step == 0:
            print(rnd)
            print(f'time: {time.time()-t0}')
        chosen = learner.act()
        learner.update(chosen, bandit.feedback(chosen))
        arm_his.append(chosen)
        regrets.append(bandit.optimal_expected_reward - bandit.expected_reward(chosen))
    return regrets, arm_his
